Toxic Comment Filter

BiLSTM model for multi-label classification
code
Deep Learning
Python, R
Author

Simone Brazzi

Published

August 12, 2024

1 Introduction

Build a model that can filter user comments based on the degree of language maliciousness:

  • Preprocess the text by eliminating the tokens that do not contribute significantly at the semantic level.
  • Transform the text corpus into sequences.
  • Build a Deep Learning model including recurrent layers for a multilabel classification task.
  • At prediction time, the model should return a vector containing a 1 or a 0 for each label in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). In this way, a non-harmful comment will be classified by a vector of only 0s, [0,0,0,0,0,0]. In contrast, a dangerous comment will exhibit at least one 1 among the 6 labels.
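The last point can be sketched as a thresholding step on the per-label probabilities returned by a sigmoid output. This is a minimal illustration, not the author's prediction code: the helper name to_label_vector, the 0.5 threshold and the probability values are assumptions made for the example.

```python
import numpy as np

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def to_label_vector(probabilities, threshold=0.5):
    """Turn per-label probabilities into a 0/1 vector (threshold is an assumption)."""
    return (np.asarray(probabilities) >= threshold).astype(int)

# made-up probabilities for a harmless and a harmful comment
clean = to_label_vector([0.02, 0.00, 0.01, 0.00, 0.03, 0.00])
harmful = to_label_vector([0.91, 0.10, 0.77, 0.02, 0.64, 0.05])

print(dict(zip(LABELS, clean.tolist())))
print(dict(zip(LABELS, harmful.tolist())))
```

A comment is flagged as soon as any entry of its vector is 1.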

2 Setup

Leveraging Quarto and RStudio, I will set up an R and Python environment.

2.1 Import R libraries

Import R libraries. These will be used both for rendering the document and for data analysis, because I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.

Code
library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)
library(gt)

reticulate::use_virtualenv("r-tf")

2.2 Import Python packages

Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp

from keras.backend import clear_session
from keras.models import Model, load_model
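# LayerNormalization is used in the model definition, so it is imported here as well
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization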
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score


from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score

Create a Config class to store all the useful parameters for the model and for the project.

2.3 Class Config

I created a class with all the basic configuration of the model, to improve the readability.

Code
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911 # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571 # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model =  self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
            F1Score(name="f1", average="macro")
            
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1", # "val_recall",
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1", # "val_recall",
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint

    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):

      # instantiate KFold
      kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
      threshold_scores = []

      for threshold in thresholds:

        cv_scores = []
        for train_index, val_index in kf.split(ytrue):

          ytrue_val = ytrue[val_index]
          yproba_val = yproba[val_index]

          ypred_val = (yproba_val >= threshold).astype(int)
          score = metric(ytrue_val, ypred_val, average="macro")
          cv_scores.append(score)

        mean_score = np.mean(cv_scores)
        threshold_scores.append((threshold, mean_score))

      # Find the threshold with the highest mean score, once all thresholds are evaluated
      best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
      return best_threshold, best_score

config = Config()
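
The threshold search in find_optimal_threshold_cv can be tried on toy data. The sketch below is a simplified variant under stated assumptions: it drops the KFold averaging, uses a hand-written macro F1 instead of a scikit-learn metric, and the names macro_f1 and best_threshold as well as the arrays are invented for the example.

```python
import numpy as np

def macro_f1(y_true, y_pred):
    """Per-label F1 averaged with equal weight (labels with no support score 0)."""
    tp = ((y_true == 1) & (y_pred == 1)).sum(axis=0)
    fp = ((y_true == 0) & (y_pred == 1)).sum(axis=0)
    fn = ((y_true == 1) & (y_pred == 0)).sum(axis=0)
    denom = 2 * tp + fp + fn
    f1 = np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)
    return float(f1.mean())

def best_threshold(y_true, y_proba, thresholds):
    """Score every threshold and return the (threshold, score) pair with the best score."""
    scores = [(float(t), macro_f1(y_true, (y_proba >= t).astype(int))) for t in thresholds]
    return max(scores, key=lambda pair: pair[1])

# toy labels and predicted probabilities for 4 comments and 2 labels
y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_proba = np.array([[0.30, 0.05], [0.10, 0.25], [0.22, 0.40], [0.12, 0.08]])

threshold, score = best_threshold(y_true, y_proba, np.arange(0.05, 0.35, 0.05))
print(threshold, score)
```

Because max returns the first best pair, ties between thresholds are resolved in favour of the lowest one, which is also what the method in Config does.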

3 Data

The dataset is downloaded from the url with tf.keras.utils.get_file. N.B. For reproducibility, I also saved a local copy of the dataset, because at times the link was not available.

Code
# df = pd.read_csv(config.path)
file = tf.keras.utils.get_file("Filter_Toxic_Comments_dataset.csv", config.url)
df = pd.read_csv(file)
Code
library(reticulate)

py$df %>%
  tibble() %>% 
  head(5) %>% 
  gt() %>% 
  tab_header(
    title = "First five observations"
  ) %>% 
   cols_align(
    align = "center",
    columns = c("toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate", "sum_injurious")
  ) %>% 
  cols_align(
    align = "left",
    columns = comment_text
  ) %>% 
  cols_label(
    comment_text = "Comments",
    toxic = "Toxic",
    severe_toxic = "Severe Toxic",
    obscene = "Obscene",
    threat = "Threat",
    insult = "Insult",
    identity_hate = "Identity Hate",
    sum_injurious = "Sum Injurious"
    )
Table 1: First 5 elements
First five observations
| Comments | Toxic | Severe Toxic | Obscene | Threat | Insult | Identity Hate | Sum Injurious |
|---|---|---|---|---|---|---|---|
| Explanation Why the edits made under my username Hardcore Metallica Fan were reverted? They weren't vandalisms, just closure on some GAs after I voted at New York Dolls FAC. And please don't remove the template from the talk page since I'm retired now.89.205.38.27 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| D'aww! He matches this background colour I'm seemingly stuck with. Thanks. (talk) 21:51, January 11, 2016 (UTC) | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| Hey man, I'm really not trying to edit war. It's just that this guy is constantly removing relevant information and talking to me through edits instead of my talk page. He seems to care more about the formatting than the actual info. | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| " More I can't make any real suggestions on improvement - I wondered if the section statistics should be later on, or a subsection of ""types of accidents"" -I think the references may need tidying so that they are all in the exact same format ie date format etc. I can do that later on, if no-one else does first - if you have any preferences for formatting style on references or want to do it yourself please let me know. There appears to be a backlog on articles for review so I guess there may be a delay until a reviewer turns up. It's listed in the relevant form eg Wikipedia:Good_article_nominations#Transport " | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| You, sir, are my hero. Any chance you remember what page that's on? | 0 | 0 | 0 | 0 | 0 | 0 | 0 |

Let's create a clean variable for EDA purposes: I want to see visually how many observations are clean versus the other labels.

Code
df.loc[df.sum_injurious == 0, "clean"] = 1
df.loc[df.sum_injurious != 0, "clean"] = 0

3.1 EDA

First a check on the dataset to find possible missing values and imbalances.

3.1.1 Frequency

Code
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels

df_r_grouped <- df_r %>% 
  select(all_of(new_labels_r)) %>%
  pivot_longer(
    cols = all_of(new_labels_r),
    names_to = "label",
    values_to = "value"
  ) %>% 
  group_by(label) %>%
  summarise(count = sum(value)) %>% 
  mutate(freq = round(count / sum(count), 4))

df_r_grouped %>% 
  gt() %>% 
  tab_header(
    title = "Labels frequency",
    subtitle = "Absolute and relative frequency"
  ) %>% 
  fmt_number(
    columns = "count",
    drop_trailing_zeros = TRUE,
    drop_trailing_dec_mark = TRUE,
    use_seps = TRUE
  ) %>% 
  fmt_percent(
    columns = "freq",
    decimals = 2,
    drop_trailing_zeros = TRUE,
    drop_trailing_dec_mark = FALSE
  ) %>% 
  cols_align(
    align = "center",
    columns = c("count", "freq")
  ) %>% 
  cols_align(
    align = "left",
    columns = label
  ) %>% 
  cols_label(
    label = "Label",
    count = "Absolute Frequency",
    freq = "Relative frequency"
  )
Table 2: Absolute and relative labels frequency
Labels frequency
Absolute and relative frequency
| Label | Absolute Frequency | Relative frequency |
|---|---|---|
| clean | 143,346 | 80.33% |
| identity_hate | 1,405 | 0.79% |
| insult | 7,877 | 4.41% |
| obscene | 8,449 | 4.73% |
| severe_toxic | 1,595 | 0.89% |
| threat | 478 | 0.27% |
| toxic | 15,294 | 8.57% |

3.1.2 Barchart

Code
library(reticulate)
barchart <- df_r_grouped %>%
  ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
  geom_col() +
  labs(
    x = "Labels",
    y = "Count"
  ) +
  # sort bars in descending order
  scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
  scale_fill_brewer(type = "seq", palette = "RdYlBu") +
  theme_minimal()
ggplotly(barchart)
Figure 1: Imbalance in the dataset with clean variable

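The bar chart shows how imbalanced the dataset is. This means it could be useful to compute class weights and pass them as an argument during training.

Most of the comments are clean. Note that the relative frequencies in Table 2 are shares of the total label count (a comment can carry several labels), so clean accounts for 80.33% of all label assignments. Measured against the 159,571 observations, the 143,346 clean comments are about 89.8% of the dataset, and only about 10.2% of the comments carry at least one toxic label.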
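
The two ways of reading the frequency of the clean label can be checked with a few lines of arithmetic. The counts are copied from Table 2 and the number of observations from Config.total_samples; the variable names are only illustrative.

```python
# label counts from Table 2
counts = {
    "clean": 143346, "identity_hate": 1405, "insult": 7877, "obscene": 8449,
    "severe_toxic": 1595, "threat": 478, "toxic": 15294,
}
total_samples = 159571  # Config.total_samples

# share of all label assignments (this is what the freq column computes)
share_of_label_counts = counts["clean"] / sum(counts.values())
# share of the observations in the dataset
share_of_observations = counts["clean"] / total_samples

print(round(share_of_label_counts, 4), round(share_of_observations, 4))
```

The first value is the 80.33% shown in the table, while the second one, about 89.83%, is the share of comments without any toxic label.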

3.2 Sequence length definition

To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.

One of its arguments is output_sequence_length: to better define it, it is useful to analyze our text length. To simulate what the layer will do, we are going to lowercase the comments and remove the punctuation and the new lines.

3.2.1 Summary

Code
library(reticulate)
df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  pull(text_length) %>% 
  summary() %>% 
  as.list() %>% 
  as_tibble() %>% 
  gt() %>% 
  tab_header(
    title = "Summary Statistics",
    subtitle = "of text length"
  ) %>% 
  fmt_number(
    drop_trailing_zeros = TRUE,
    drop_trailing_dec_mark = TRUE,
    use_seps = TRUE
  ) %>% 
  cols_align(
    align = "center",
  ) %>% 
  cols_label(
    Min. = "Min",
    `1st Qu.` = "Q1",
    Median = "Median",
    `3rd Qu.` = "Q3",
    Max. = "Max"
  )
Table 3: Summary of text length
Summary Statistics
of text length
| Min | Q1 | Median | Mean | Q3 | Max |
|---|---|---|---|---|---|
| 4 | 91 | 196 | 378.4 | 419 | 5,000 |

3.2.2 Boxplot

Code
library(reticulate)
boxplot <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
    ) %>% 
  # pull(text_length) %>% 
  ggplot(aes(y = text_length)) +
  geom_boxplot() +
  theme_minimal()
ggplotly(boxplot)
Figure 2: Text length boxplot

3.2.3 Histogram

Code
library(reticulate)
df_ <- df_r %>% 
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>% 
      str_remove_all("[[:punct:]]") %>% 
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  )

Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)

histogram <- df_ %>% 
  ggplot(aes(x = text_length)) +
  geom_histogram(bins = 50) +
  geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
  theme_minimal() +
  xlab("Text Length") +
  ylab("Frequency") +
  xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Figure 3: Text length histogram with boxplot upper fence

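Considering the analysis above, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot. In the last plot, it is the dashed red vertical line. Doing so, we truncate only the outliers, which are a small part of our dataset. Note that text_length counts characters, while output_sequence_length is measured in tokens, so 911 is a generous bound: a comment of 911 characters contains fewer than 911 whitespace-separated tokens.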
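
The value 911 can be reproduced from the quartiles in Table 3 with the usual boxplot rule, upper fence = Q3 + 1.5 * IQR. A short check, with illustrative variable names:

```python
q1, q3 = 91, 419  # first and third quartile of text length, from Table 3

iqr = q3 - q1                       # interquartile range
upper_fence = int(q3 + 1.5 * iqr)   # same rule as in the R histogram code

print(upper_fence)
```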

3.3 Dataset

Now we can split the dataset in 3: train, test and validation sets. Since sklearn has no single function that splits the data into these 3 sets, we can do the following:

  • split the data into a train set and a temporary set, with a 0.3 share for the temporary set.
  • split the temporary set into 2 equally sized test and validation sets.

Code
x = df[config.features].values
y = df[config.labels].values

xtrain, xtemp, ytrain, ytemp = train_test_split(
  x,
  y,
  test_size=config.temp_split, # .3
  random_state=config.random_state
  )
xtest, xval, ytest, yval = train_test_split(
  xtemp,
  ytemp,
  test_size=config.test_split, # .5
  random_state=config.random_state
  )

  • xtrain shape: (111699,), ytrain shape: (111699, 6)
  • xtest shape: (23936,), ytest shape: (23936, 6)
  • xval shape: (23936,), yval shape: (23936, 6)
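
The sizes of the three sets follow from the two split ratios. Here is a small sketch of the arithmetic, assuming the test share is rounded up, which is how train_test_split treats a float test_size; the helper name split_sizes is hypothetical.

```python
import math

def split_sizes(n_samples, test_size):
    """Return (first part, second part) sizes for a float test_size."""
    n_second = math.ceil(test_size * n_samples)  # the test share is rounded up
    return n_samples - n_second, n_second

n_train, n_temp = split_sizes(159571, 0.3)  # train vs temporary set
n_test, n_val = split_sizes(n_temp, 0.5)    # temporary set split in half

print(n_train, n_test, n_val)
```

The train size matches Config.train_samples and the validation size matches Config.val_samples.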

The datasets are created with tf.data.Dataset, which builds a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. A tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which creates a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.

Code
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtrain, ytrain))
    .shuffle(xtrain.shape[0])
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtest, ytest))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset
    .from_tensor_slices((xval, yval))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Code
print(
  f"train_ds cardinality: {train_ds.cardinality()}\n",
  f"val_ds cardinality: {val_ds.cardinality()}\n",
  f"test_ds cardinality: {test_ds.cardinality()}\n"
  )
train_ds cardinality: 3491
 val_ds cardinality: 748
 test_ds cardinality: 748

Check the first batch of the dataset to be sure that the pipeline is built correctly.

Code
train_ds.as_numpy_iterator().next()
(array([b'Neither of those edits were vandalism and they should not have been reverted by you. Both articles should be deleted and are ridiculous entries in wikipedia. Also please have enough respect to sign your comments to my talk page. 24.235.129.212',
       b'"\nHaha... you\'re welcome. I briefly contemplated the irony of that vandalism placement too, but I also appreciated the converse irony of blocking the vandal on the edit right placed right above ""until you get banned"". Thanks for getting the vandalism on my page too. Best,   "',
       b'"\n\n ""Notable individuals associated with Lambeth"" \n\nI recently added David Bowie, John Major, Ted Rogers and Roger Moore to the ""Notable individuals associated with Lambeth"" section, all of whom were born and/or grew up in Lambeth, and they were removed almost immediately. And I\'m very very curious to know why! God forbid it was out of snobbery - are only poets and composers allowed?!"',
       b'"\n\nI have no concerns about his edit, but it is unsourced. I trust this guy, but we may have to wait until his Encyclopedia about Galliformes comes out and totally revises the taxonomy. Pinudjem is his second part of his middle name (Amoun-Pinudjem) of his alternate name (Milad Sourial). If you search for that name on the web you can see that he is making a movie called ""Goddess of the Sun"", not only from the photo gallery concerning the Green Peafowl but also from a site not by him.\n\nKermit does not have a real name. His real parents gave him a name he didn\'t like. Sadly, his real parents died in a political assacination. He was then named Kermit Roosevelt by his adoptive parents, but he didn\'t like that either. He has three user accounts: User:Amoun-Pinudgem (middle name), User:Milad A.P. Sourial, and his newest is User:Pinudjem. However, he has also edited with an IP once, which was reverted. He often writes under Kermit Blackwood and was also quoted in the Red Data Book as K. B. Woods (who suggested that Greens are monogamous and possible Yunnan subspecies).\n\nHe is a writer on the Feathersite, a site that talks about poultry, wild ducks and gamebirds, etc.\n\nOne must know that Kermit is not sure that the Green Peafowl is a superspecies complex, but that he created that hypothesis, and that he believes the best way to know the truth is to try and disprove it. However, preliminary data confirms his hypothesis of multiple species. He has already collected tissue samples and compared them but he believes it is better if nuclear DNA can also be extracted from samples.\n\nWolfgang Mennig, a breeder who suggests that there are additional subspecies (not species), also confirms (he e-mailed me) that he is friends with Kermit and that they try to work together as much as possible.  "',
       b'"\nThanks Flyer22 that was not my guess.  Here\'s what the NBC website says, the official site, not a message board: Lucas was injured when a beam fell on his legs and EJ agreed to save Lucas only if Sami agrees to let EJ have his way with her. Sami reluctantly had sex with EJ and helped him escape to Mexico, after he was wanted in connection with the shooting of John Black. Despite Sami\xe2\x80\x99s attempt to turn over a new leaf, EJ returned and pressed her for information on police activities, and she was pulled into another web of deception, trying to prevent him from learning she was pregnant. That\'s from Sami Brady.  That tells me the show\'s official position is that she agreed.  What a character says and what the show\'s position is in this case are not the same.  Since Ed Scott took over, only Lucas has used the word and Sami and EJ have shared a couple of kisses as well as reading the Santo and Colleen letters together. It has been indicated they are going to ""tighten things up"" and fill in some wholes. The anticipation by some fans (I know, there\'s that ""some"" word but that is the case with people I have encountered) is that the December 29th issue will be resolved soon.   "',
       b'"\n Your submission at Articles for creation \n Ratikant Kanungo, which you submitted to Articles for creation, has been created. The article has been assessed as Start-Class, which is recorded on the article\'s talk page. You may like to take a look at the grading scheme to see how you can improve the article.\nYou are more than welcome to continue making quality contributions to Wikipedia. .\n If you have any questions, you are welcome to ask at the help desk.\n If you would like to help us improve this process, please consider .\nThank you for helping improve Wikipedia!\n  "',
       b'Archived talk \n\nNo idea about pork rib tea, but I find it very different form pho due to dang gui. Then again, the cuisine in the region is a bit mixed so, maybe?\n\nShould I archive my talk page?',
       b'Re Stalin Article \n\nHi:\nYour remarks hit the nail on the head. The Stalin article is a typical example of the left wing anti-American bias that is considered NPV by the folks at Wikipedia. I refuse to contribute as long as such nonsense is tolerated.\nBerndd11222',
       b'Sleazebag \n\nJimbo, you are a sleazebag. Please go away.',
       b'The latter of the two comments above constitutes a personal attack. Please do not make personal attacks towards another Wikipedia user, even if their opinion differs from yours.',
       b'"\n\nBecause your change is incorrect it is not known in French as ""Conseil Europ\xc3\xa9en pour la Recherche Nucl\xc3\xa9aire"" and hasn\'t been for quite a number of years in French it is ""Organisation europ\xc3\xa9enne pour la recherche nucl\xc3\xa9aire"". See here. cheers kri "',
       b"I don't understand what you're talking about. It was not me that changed the title (see Discussion page). It was already Equestrian order when you looked at it. What missing informnation are you referring to?",
       b'New submission of China Energy Fund Committee (CEFC) by SUNREST \n\nDear, this is my new submission of CEFC\n\n China Energy Fund Committee (CEFC) \n      \nChina Energy Fund Committee (CEFC) is a non-governmental, non-profit civil society organization. The Committee is an NGO with Special Consultative Status, the United Nations Economic and Social Council (UN ECOSOC). Registered in Hong Kong, the Committee obtains tax exemption under Section 88 of the Inland Revenue Ordinance as a charitable organization. Also registered in Virginia, the United States, the Committee obtains tax exemption under Section 501(c)(3) of the Internal Revenue Code as a public charity.     \xe2\x86\x93',
       b"re:  RfA Suggestion \n\nWith all due respect, I'd prefer not to withdraw it just yet.  While I recognize that I am certainly not off to anywhere near a good start, I'd like to let it run at least 24-48 hours to - if nothing else - draw a good amount of feedback from the community.  Now, if it becomes a WP:SNOW situation, I'll be more than willing to withdraw early, but part of why I'm doing this now is to gather some honest feedback from the community about where I stand and where I need to improve, if not successfully passing.  Regards,",
       b'i think dapi misquoted glantz so i want to read the sources. i this false? and fact is that german units were back soon an started pounding russian units so i think dapi highlighted something which is not as cool as he thinks. value for the reader....',
       b"RV \nSorry about the RV, went to check on this topic today and saw that someone? erased the page under my username. I'll change my password immediatly.",
       b'I couldnt care less what you say, fuck off from my talk page 31.209.16.177',
       b'Nowhere in the links you provided do I see that MPAA and BBFC classifications are non-notable. They are notable enough to have Wikipedia articles devoted to them. The information is displayed prominently at the beginning of almost all previews and in advertising for movies because of its notability. The information includes the UK and the US, the two largest English-speaking countries. How can you assert there is a Worldview bias?',
       b'Please stop.  If you continue to vandalize pages, you will be blocked from editing Wikipedia. \xe2\x80\x93 (t/c)',
       b"Re Madrid Links \n\nThank you very much for your message.\n\nI only added the link in the useful links section because I believed that it was relevant to the article in question. It is not an advert, a way for me to make any money or an attempt to trick the originally collective spirit of Wikipedia. \n\nObviously I do not understand Wikipedia, as it seems to be the preserved playground of the few. So who exactly polices the police here?\n\nThanks for taking the time to message me, but I don't think that I will take the time to contribute again.",
       b'"\n\n Song move \n\nYou gave me some good advice about moving and redirecting pages. Thank you.\n\nUnfortunately, I made a real mess when I tried to move ""Let It Snow"" to ""Let It Snow! Let It Snow! Let It Snow!""  When I did the redirect, I accidentally mistyped the title as ""Let It Snow! Let It Snoe! Let It Snow!"" (notice the ""Snoe"") and forgot to look before I leapt. If that wasn\'t bad enough, I then went to the Redirect page for ""Let It Snow"" and, instead of leaving the misspelling alone, I changed it to redirect to ""Let It Snow! Let It Snow! Let It Snow!"" (the correct spelling), thus creating that page, which did not previously exist. Thus, when I tried to change the misspelled name to the correct name, I of course got the error that the other page already exists.  \n\nNot to mention that there were some double redirects when I looked in ""What links here,"" which I was going to start to fix until I realized how badly I messed up the original move. (It turns out that someone had renamed the page in 2005, with no discussion, from ""Let It Snow, Let It Snow, Let It Snow"" to ""Let It Snow,"" a move I would have vehemently protested if I\'d been hanging around here at the time.) \n\nI need some help to undo this. The simplest thing would be to delete the current ""redirect"" page for the correct title and resume from there, renaming my misspelled page and then looking at the redirects again, but I can\'t delete pages, and I know the rules about cutting and pasting.  "',
       b'Bgwhite, all pages with a Surname Clarification Template need a common. Is that right?',
       b"I have no desire to edit anything, those who put up nonsense should be the ones removing it.  The reality is as it stands, irreligion is associated ONLY with Christianity.  This is a SERIOUS POV issue and those who put the portal link need to address this question.  It does NOT logically follow that one group of Christians' views should lead to this portal link being placed in the talk page to the exclusion of all other religions.  It is frankly speaking an attempt to frame the article in a manner that targets one religion, and that is not acceptable at all.",
       b"I live in same country as Xherdan Shaqiri takes to me less than 30 min to go to his hometown so don't tell me to where he lived in Yugoslavia, is done there is no Yugoslavia no more stop living in the past we have future ahead us. And one thing you can change the name in Wiki where he was born but not the FACT <3\n  22:03, 5 July 2015 (GMT)",
       b'LOL \nlol, not what I expected when I typed pair of pants into my wiki search bar',
       b'February 2007 (UTC)\nAh, thanks! Response on my talk page. \xe2\x80\x94 [T@lk/Improve me] 04:38, 19',
       b"The article's looking better and better!  One other idea for an article, if you're still searching: there's a lot out there on Eddie Kinner.  talk",
       b"=]\n\nWE ARE GOING TO MAKE SURE YOU ARE EXPOSED!! DROP YOUR RELATIONS WITH ANTI AMERICAN HATE GROUPS\n''WE ARE GOING TO MAKE SURE YOU ARE EXPOSED!! DROP YOUR RELATIONS WITH ANTI AMERICAN HATE GROUPS",
       b"List of Shania Twain music videos \n\n List of Shania Twain music videos \n\nHi, as you know, we both have been editing articles related to Shania Twain. There's an article, List of Shania Twain music videos, which is nominated as a featured list. If you would like, you can support it at Wikipedia:Featured list candidates. All you have to do is put Support followed by your signature using ~~~~. Thanks!",
       b'(sorry Sinebot - lol, my little bot buddy)',
       b'"\nChanging to another username: Superkidd. It\'s the name I use in most video games. "',
       b"do we discredit all sources which have made errors, because if so, there's a certain frequently-cited columnist we will need to purge. And anyway, regardless of his errors reading Wikipedia output, the thrust of his claims had merit, and"],
      dtype=object), array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]]))

We also check the shapes. We expect features of shape (batch, ) and targets of shape (batch, number of labels).

Code
print(
  f"text train shape: {train_ds.as_numpy_iterator().next()[0].shape}\n",
  f" text train type: {train_ds.as_numpy_iterator().next()[0].dtype}\n",
  f"label train shape: {train_ds.as_numpy_iterator().next()[1].shape}\n",
  f"label train type: {train_ds.as_numpy_iterator().next()[1].dtype}\n"
  )
text train shape: (32,)
  text train type: object
 label train shape: (32, 6)
 label train type: int64

4 Preprocessing

Of course preprocessing! Raw text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:

  1. Standardize each example (usually lowercasing + punctuation stripping).
  2. Split each example into substrings (usually words).
  3. Recombine substrings into tokens (usually ngrams).
  4. Index tokens (associate a unique int value with each token).
  5. Transform each example using this index, either into a vector of ints or a dense float vector.
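
The steps above can be imitated in plain Python to see what comes out of the layer in "int" mode. This is only a rough sketch of the idea, not the Keras implementation: the function names, the tiny corpus and the use of "[UNK]" for unknown words are assumptions, and the n-gram step is skipped.

```python
import string
from collections import Counter

def standardize(text):
    """Lowercase and strip ASCII punctuation."""
    return text.lower().translate(str.maketrans("", "", string.punctuation))

def build_vocab(texts, max_tokens):
    """Index the most frequent tokens; 0 is kept for padding and 1 for unknown words."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    vocab = ["", "[UNK]"] + [tok for tok, _ in counts.most_common(max_tokens - 2)]
    return {tok: i for i, tok in enumerate(vocab)}

def vectorize(text, vocab, output_sequence_length):
    """Map tokens to ints, then truncate or pad with 0 to a fixed length."""
    ids = [vocab.get(tok, 1) for tok in standardize(text).split()]
    ids = ids[:output_sequence_length]
    return ids + [0] * (output_sequence_length - len(ids))

corpus = ["You, sir, are my hero.", "Are you my hero?"]
vocab = build_vocab(corpus, max_tokens=6)
vector = vectorize("You are MY villain!", vocab, output_sequence_length=6)
print(vector)
```

Known words get an index of 2 or more, the unseen word "villain" is mapped to 1, and the remaining positions are filled with 0.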

For more details, see the TextVectorization documentation.

Code
text_vectorization = TextVectorization(
  max_tokens=config.max_tokens,
  standardize="lower_and_strip_punctuation",
  split="whitespace",
  output_mode="int",
  output_sequence_length=config.output_sequence_length,
  pad_to_max_tokens=True
  )

# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)

This layer is set to:

  • max_tokens: 20000, a common value for text classification. It is the maximum size of the vocabulary for this layer.
  • output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
  • output_mode: "int", which outputs integer indices, one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
  • standardize: "lower_and_strip_punctuation".
  • split: on whitespace.

To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization layer, it is possible to map the layer to the features of each dataset.

Code
processed_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)

5 Model

5.1 Definition

Define the model using the Functional API.

Code
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # Global average pooling
    x = GlobalAveragePooling1D()(x)
    # Add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    
    return model

lstm_model = get_deeper_lstm_model()
lstm_model.summary()

5.2 Callbacks

Finally, the model has been trained using 2 callbacks:

  • Early Stopping, to avoid wasting Kaggle GPU time.
  • Model Checkpoint, to retrieve the best model training information.

Code
my_es = config.get_early_stopping()
my_mc = config.get_model_checkpoint(filepath="/checkpoint.keras")
callbacks = [my_es, my_mc]
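
The helpers config.get_early_stopping and config.get_model_checkpoint are not shown here, so the following is only a sketch of the general idea behind patience-based early stopping, written in plain Python. The class name PatienceStopper, the patience value and the choice of validation loss as monitored quantity are assumptions, not the settings used in this project.

```python
class PatienceStopper:
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = None
        self.wait = 0

    def update(self, epoch, val_loss):
        # Returns True when training should stop.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.best_epoch = epoch  # a checkpoint would be saved at this point
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience

# Hypothetical validation losses, one per epoch.
val_losses = [0.20, 0.12, 0.09, 0.10, 0.11, 0.095, 0.08]
stopper = PatienceStopper(patience=3)
for epoch, loss in enumerate(val_losses):
    if stopper.update(epoch, loss):
        print(f"stopped at epoch {epoch}, best epoch {stopper.best_epoch}")
        break
```

In this toy run the last value (0.08) is never reached: three epochs without improvement are enough to stop, and the checkpoint from the best epoch is the one worth reloading.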

5.3 Final preparation before fit

Because the dataset is imbalanced, we compute a weight for each class to improve performance. These weights are passed to the model during training.

Code
lab = pd.DataFrame(columns=config.labels, data=ytrain)
r = lab.sum() / len(ytrain)
class_weight = dict(zip(range(len(config.labels)), r))
df_class_weight = pd.DataFrame.from_dict(
  data=class_weight,
  orient='index',
  columns=['class_weight']
  )
df_class_weight.index = config.labels
Code
library(reticulate)
py$df_class_weight %>% 
  gt() %>% 
  fmt_percent(
    decimals = 2,
    drop_trailing_zeros = TRUE,
    drop_trailing_dec_mark = TRUE
  )
Table 4: Class weight
label class_weight
toxic 9.59%
severe_toxic 0.99%
obscene 5.28%
threat 0.31%
insult 4.91%
identity_hate 0.87%
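
For readers who want to see the arithmetic without pandas, here is a minimal sketch on toy data. The label rows are hypothetical. The first dictionary reproduces the per-label frequency computed above; the second shows inverse-frequency weighting, a common alternative that gives rare labels a larger weight. That alternative is not what this project uses and is included only for comparison.

```python
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Toy multi-hot label rows (hypothetical data, one row per comment).
y = [
    [1, 0, 1, 0, 1, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
]

n = len(y)
# Share of comments carrying each label (same computation as lab.sum() / len(ytrain)).
frequency = {i: sum(row[i] for row in y) / n for i in range(len(labels))}

# Common alternative (not what the code above does): weight rare labels
# more by inverting the frequency; labels with no positives get weight 0.
inverse = {i: (1 / f if f > 0 else 0.0) for i, f in frequency.items()}

for i, name in enumerate(labels):
    print(f"{name:14s} freq={frequency[i]:.2f} inverse={inverse[i]:.2f}")
```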

It is also useful to define the steps per epoch for the train and validation datasets. This step is required so that the fit consumes the entire dataset; without it, part of the data was skipped, which happened to me.

Code
steps_per_epoch = config.train_samples // config.batch_size
validation_steps = config.val_samples // config.batch_size
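
A quick sketch of the arithmetic, with made-up sizes rather than the values stored in config. The function name steps_for is hypothetical. Integer division counts only the full batches that fit in one pass, so a few samples can be left over.

```python
def steps_for(n_samples, batch_size):
    # Number of full batches that fit in one pass over the data.
    return n_samples // batch_size

# Hypothetical sizes, not the values stored in config.
train_samples, batch_size = 1000, 32
steps = steps_for(train_samples, batch_size)
leftover = train_samples - steps * batch_size

print(steps, leftover)
```

With these numbers each epoch runs 31 steps and 8 samples do not fit in a full batch.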

5.4 Fit

The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:

  • .repeat() ensures the model sees the whole dataset.
  • epochs is set to 100.
  • validation_data has the same repeat.
  • callbacks are the ones defined before.
  • class_weight ensures the model is trained using the frequency of each class, because our dataset is imbalanced.
  • steps_per_epoch and validation_steps are needed because repeat() makes the datasets endless.
Code
history = lstm_model.fit(
  processed_train_ds.repeat(),
  epochs=config.epochs,
  validation_data=processed_val_ds.repeat(),
  callbacks=callbacks,
  class_weight=class_weight,
  steps_per_epoch=steps_per_epoch,
  validation_steps=validation_steps
  )

Now we can import the model and the history trained on Kaggle.

Code
model = load_model(filepath=config.model)
history = pd.read_excel(config.history)

5.5 Evaluate

Code
validation = model.evaluate(
  processed_val_ds.repeat(),
  steps=validation_steps, # 748
  verbose=0
  )
Code
val_metrics <- tibble(
  metric = c("loss", "precision", "recall", "auc", "f1_score"),
  value = py$validation
  )
val_metrics %>% 
  gt() %>% 
  fmt_number(
    columns = c("value"),
    decimals = 4,
    drop_trailing_zeros = TRUE,
    drop_trailing_dec_mark = TRUE
  ) %>% 
  cols_align(
    align = "left",
    columns = metric
  ) %>% 
  cols_align(
    align = "center",
    columns = value
  ) %>% 
  cols_label(
    metric = "Metric",
    value = "Value"
  )
Table 5: Model validation metric
Metric Value
loss 0.0542
precision 0.7888
recall 0.671
auc 0.9572
f1_score 0.0293

5.6 Predict

For the prediction, the model does not need to repeat the dataset, because it has already been trained on all of the train data. Now it only has to consume the new data once to make the prediction.

Code
predictions = model.predict(processed_test_ds, verbose=0)

5.7 Confusion Matrix

A good way to assess the performance of a multi label classifier is a confusion matrix. Sklearn has a specific function, multilabel_confusion_matrix, which builds one confusion matrix per label to handle the fact that there could be multiple labels for one prediction.
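
To show what "one matrix per label" means, here is a dependency-free sketch. The function name multilabel_confusion is hypothetical and the data is a toy example. It uses the same [[TN, FP], [FN, TP]] layout as sklearn, with rows for the true value and columns for the predicted value.

```python
def multilabel_confusion(y_true, y_pred):
    # One [[TN, FP], [FN, TP]] matrix per label column.
    n_labels = len(y_true[0])
    matrices = [[[0, 0], [0, 0]] for _ in range(n_labels)]
    for true_row, pred_row in zip(y_true, y_pred):
        for j, (t, p) in enumerate(zip(true_row, pred_row)):
            matrices[j][t][p] += 1  # row = true value, column = predicted value
    return matrices

# Toy example: 4 comments, 2 labels.
y_true = [[1, 0], [1, 1], [0, 0], [0, 1]]
y_pred = [[1, 0], [0, 1], [0, 1], [0, 1]]
for j, m in enumerate(multilabel_confusion(y_true, y_pred)):
    print(j, m)
```

Each label is treated as its own binary problem, so a comment can be a true positive for one label and a false positive for another.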

5.7.1 Grid Search Cross Validation for best threshold

Grid Search CV is a technique for fine-tuning the hyperparameters of an ML model. It systematically searches through a set of hyperparameter values to find the combination that leads to the best model performance. In this case, I am using KFold Cross Validation, a resampling technique that splits the data into k consecutive folds. Each fold is used once as validation while the k - 1 remaining folds form the training set. See the documentation for more information.

The model is tuned to favor recall. This decision was made because the cost of a False Negative is greater than the cost of a False Positive: missing an injurious comment is worse than flagging a clean one as bad.
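
The helper config.find_optimal_threshold_cv used later is not shown in this post, so the following is only one plausible shape for such a search, written in plain Python for a single label. All names (consecutive_folds, best_threshold), the threshold grid and the number of folds are assumptions. No model is refit here: the folds simply provide several held-out estimates of the metric for each candidate threshold, which are then averaged.

```python
def consecutive_folds(n_samples, k):
    # Split indices 0..n-1 into k consecutive folds (no shuffling).
    # The first n % k folds get one extra sample.
    base, extra = divmod(n_samples, k)
    folds, start = [], 0
    for i in range(k):
        size = base + (1 if i < extra else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds

def recall(y_true, y_pred):
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return tp / (tp + fn) if tp + fn else 0.0

def best_threshold(y_true, y_proba, metric, thresholds, k=3):
    # Score every threshold on each fold, then average over folds.
    folds = consecutive_folds(len(y_true), k)
    best_t, best_score = None, -1.0
    for t in thresholds:
        scores = []
        for fold in folds:
            true_f = [y_true[i] for i in fold]
            pred_f = [int(y_proba[i] >= t) for i in fold]
            scores.append(metric(true_f, pred_f))
        mean_score = sum(scores) / len(scores)
        if mean_score > best_score:  # ties keep the earlier threshold
            best_t, best_score = t, mean_score
    return best_t, best_score

# Hypothetical labels and predicted probabilities for one label.
y_true = [1, 0, 1, 0, 0, 1]
y_proba = [0.9, 0.2, 0.4, 0.1, 0.6, 0.3]
print(best_threshold(y_true, y_proba, recall, [0.25, 0.5, 0.75]))
```

When the metric is recall, the lowest threshold tends to win, since lowering the threshold can only add positive predictions.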

5.7.2 Confidence threshold and Precision-Recall trade off

Whilst the KFold Grid Search CV technique is useful to test multiple hyperparameters, it is important to understand the problem we are facing. A multi label deep learning classifier outputs a vector of per-class probabilities. These need to be converted to a binary vector using a confidence threshold.

  • The higher the threshold, the fewer classes the model predicts, increasing model confidence [higher Precision] and increasing missed classes [lower Recall].
  • The lower the threshold, the more classes the model predicts, decreasing model confidence [lower Precision] and decreasing missed classes [higher Recall].
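
The trade-off described in the two points above can be checked on a toy example. The scores below are hypothetical values for a single label, and precision_recall is an illustrative helper, not a function from this project.

```python
def precision_recall(y_true, y_proba, threshold):
    # Binarize the probabilities, then count the outcomes.
    y_pred = [int(p >= threshold) for p in y_proba]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical scores for one label: 3 positives, 5 negatives.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_proba = [0.9, 0.6, 0.2, 0.7, 0.3, 0.1, 0.05, 0.02]
for threshold in (0.05, 0.5, 0.8):
    p, r = precision_recall(y_true, y_proba, threshold)
    print(f"threshold={threshold:.2f} precision={p:.2f} recall={r:.2f}")
```

On this data, raising the threshold from 0.05 to 0.8 moves precision from 3/7 to 1 while recall drops from 1 to 1/3.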

Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnosis. It is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to miss the cancer when the patient has the disease [False Negative].

I decided to train the model on the F1 score to obtain a model balanced in both precision and recall, and to leave it to the threshold selection to increase the recall performance.

Moreover, the model has been trained on the macro average F1 score, a single performance indicator obtained as the mean of the F1 scores of the individual classes, where each F1 score is the harmonic mean of that class's Precision and Recall.

\[ F1\ macro\ avg = \frac{\sum_{i=1}^{n} F1_i}{n} \]

It is useful with imbalanced classes, because it weights each class equally and is not influenced by the number of samples in each class. This is set both in config.metrics and in find_optimal_threshold_cv.
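
A small worked example may help to see why macro averaging matters with imbalanced classes. The counts below are hypothetical, with one frequent label and one rare label; micro averaging is shown only for contrast.

```python
def f1(tp, fp, fn):
    # F1 = 2TP / (2TP + FP + FN), the harmonic mean of precision and recall.
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

# Hypothetical (TP, FP, FN) counts: one frequent label, one rare label.
counts = {"toxic": (80, 20, 20), "threat": (1, 3, 4)}

per_label = {name: f1(*c) for name, c in counts.items()}
macro = sum(per_label.values()) / len(per_label)

# Micro averaging pools the counts first, so the frequent label dominates.
tp, fp, fn = (sum(c[i] for c in counts.values()) for i in range(3))
micro = f1(tp, fp, fn)

print(per_label, round(macro, 4), round(micro, 4))
```

The poor score on the rare label pulls the macro average down to about 0.51, while the micro average stays near 0.78 because it is dominated by the frequent label.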

f1_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_f1, best_score_f1 = config.find_optimal_threshold_cv(ytrue, y_pred_proba, f1_score)

print(f"Optimal threshold: {optimal_threshold_f1}")
Optimal threshold: 0.15000000000000002
Code
print(f"Best score: {best_score_f1}")
Best score: 0.4788653077945807
Code

# Use the optimal threshold to make predictions
final_predictions_f1 = (y_pred_proba >= optimal_threshold_f1).astype(int)

Optimal threshold f1 score: 0.15. Best score: 0.4788653.

recall_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_recall, best_score_recall = config.find_optimal_threshold_cv(ytrue, y_pred_proba, recall_score)

# Use the optimal threshold to make predictions
final_predictions_recall = (y_pred_proba >= optimal_threshold_recall).astype(int)

Optimal threshold recall: 0.05. Best score: 0.8095814.

roc_auc_score

Code
ytrue = ytest.astype(int)
y_pred_proba = predictions
optimal_threshold_roc, best_score_roc = config.find_optimal_threshold_cv(ytrue, y_pred_proba, roc_auc_score)

print(f"Optimal threshold: {optimal_threshold_roc}")
Optimal threshold: 0.05
Code
print(f"Best score: {best_score_roc}")
Best score: 0.8809499649742268
Code

# Use the optimal threshold to make predictions
final_predictions_roc = (y_pred_proba >= optimal_threshold_roc).astype(int)

Optimal threshold roc: 0.05. Best score: 0.88095.

5.7.3 Confusion Matrix Plot

Code
# convert predicted probabilities to binary predictions
ypred = predictions >=  optimal_threshold_recall # .05
ypred = ypred.astype(int)

# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=axes[i], colorbar=False)
    axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
Figure 4: Multi Label Confusion matrix

5.8 Classification Report

Code
cr = classification_report(
  ytrue,
  ypred,
  target_names=config.labels,
  digits=4,
  output_dict=True
  )
df_cr = pd.DataFrame.from_dict(cr).reset_index()
Code
library(reticulate)
df_cr <- py$df_cr %>% dplyr::rename(names = index)
cols <- df_cr %>% colnames()
df_cr %>% 
  pivot_longer(
    cols = -names,
    names_to = "metrics",
    values_to = "values"
  ) %>% 
  pivot_wider(
    names_from = names,
    values_from = values
  ) %>% 
  gt() %>%
  tab_header(
    title = "Confusion Matrix",
    subtitle = "Threshold optimization favoring recall"
  ) %>% 
  fmt_number(
    columns = c("precision", "recall", "f1-score", "support"),
    decimals = 2,
    drop_trailing_zeros = TRUE,
    drop_trailing_dec_mark = FALSE
  ) %>% 
  cols_align(
    align = "center",
    columns = c("precision", "recall", "f1-score", "support")
  ) %>% 
  cols_align(
    align = "left",
    columns = metrics
  ) %>% 
  cols_label(
    metrics = "Metrics",
    precision = "Precision",
    recall = "Recall",
    `f1-score` = "F1-Score",
    support = "Support"
  )
Table 6: Classification report
Confusion Matrix
Threshold optimization favoring recall
Metrics Precision Recall F1-Score Support
toxic 0.55 0.89 0.68 2,262.
severe_toxic 0.24 0.92 0.37 240.
obscene 0.55 0.94 0.69 1,263.
threat 0.04 0.49 0.07 69.
insult 0.47 0.91 0.62 1,170.
identity_hate 0.12 0.72 0.2 207.
micro avg 0.42 0.9 0.57 5,211.
macro avg 0.33 0.81 0.44 5,211.
weighted avg 0.49 0.9 0.63 5,211.
samples avg 0.05 0.08 0.06 5,211.

6 Conclusions

The BiLSTM model, optimized for high recall, performs well enough to make predictions for each label. Considering the low support for the threat label, the performance is not bad. See Table 2 and Figure 1: the threat label is only 0.27% of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment as injurious.

Possible improvements include increasing the number of observations, especially for the threat label. In general there are too many clean comments. This could be addressed by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.